[WIP] Add RTF autoencoder enrichment model#447

Open
KeshavMajithia wants to merge 4 commits into boom-astro:main from KeshavMajithia:add-rtf-model

@KeshavMajithia

Work in progress

Adds a Rust data loader for the RTF (Real-Time Filter) anomaly detection autoencoder, following the existing ACAI/BTSBot pattern.

What this does

  • rtf.rs: Standalone data loader that converts a ZTF alert into a (1, 257, 37) input tensor and runs ONNX inference via ort
  • Produces a 128-dim embedding vector and a scalar reconstruction error (anomaly score)
  • Reads full photometry history from prv_candidates and fp_hists
  • Extracts 30 metadata fields from the ZTF candidate, broadcast across all timesteps
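
The tensor layout described above can be sketched without any tensor library. This is a minimal illustration, not the code in rtf.rs: it assumes the 37 features split into 7 per-detection columns followed by the 30 broadcast metadata columns, that padded rows stay zero, that `true` in the mask marks padding, and that histories longer than 257 keep only the most recent detections. The names `build_padded_input`, `N_PER_STEP`, and `N_META` are hypothetical.

```rust
/// Sequence length and feature width from the (1, 257, 37) shape in the PR.
const SEQ_LEN: usize = 257;
const N_FEATURES: usize = 37;
/// 30 metadata fields per the PR description; the other 7 are per detection.
const N_META: usize = 30;
const N_PER_STEP: usize = N_FEATURES - N_META;

/// Builds a flat row-major (1, SEQ_LEN, N_FEATURES) buffer plus a padding mask.
/// ASSUMPTIONS (not confirmed by the PR): metadata fills the trailing 30
/// columns, padded rows stay zero, `true` in the mask marks padding, and
/// `per_step` is ordered oldest first.
fn build_padded_input(
    per_step: &[[f32; N_PER_STEP]],
    metadata: &[f32; N_META],
) -> (Vec<f32>, Vec<bool>) {
    let mut x = vec![0.0_f32; SEQ_LEN * N_FEATURES];
    let mut pad_mask = vec![true; SEQ_LEN];

    // If the history is longer than SEQ_LEN, keep only the trailing entries.
    let start = per_step.len().saturating_sub(SEQ_LEN);
    for (i, row) in per_step[start..].iter().enumerate() {
        let base = i * N_FEATURES;
        x[base..base + N_PER_STEP].copy_from_slice(row);
        // Broadcast the same metadata vector onto every real timestep.
        x[base + N_PER_STEP..base + N_FEATURES].copy_from_slice(metadata);
        pad_mask[i] = false;
    }
    (x, pad_mask)
}
```

With three detections, the first three rows carry data and metadata while the remaining 254 rows stay zero with their mask entries set. The real loader may order columns differently or use the opposite mask polarity, so the ONNX export should be checked before relying on either choice.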

ONNX model details

  • rtf_embed.onnx (8.4 MB): inputs (x, pad_mask) → 128-dim embedding
  • rtf_recon.onnx (13 MB): inputs (x, pad_mask) → scalar recon error

Not included in this PR

  • Integration into ZtfEnrichmentWorker

Adds a Rust data loader for the RTF (Real-Time Filter) anomaly detection
autoencoder. The model takes a ZTF alert's full photometry history,
builds a (1, 257, 37) input tensor, and runs inference through two ONNX
graphs to produce a 128-dim embedding and a scalar reconstruction error
(anomaly score).

Follows the existing ACAI/BTSBot pattern in src/enrichment/models/.
ONNX model files (rtf_embed.onnx, rtf_recon.onnx) to be hosted on
HuggingFace under the AppleCiDEr organization.
Copilot AI review requested due to automatic review settings April 23, 2026 16:53

Copilot AI left a comment

Pull request overview

Adds a new enrichment-model loader for the RTF (Real-Time Filter) autoencoder so the codebase can build the model’s (1, 257, 37) input tensor from ZTF alert photometry + candidate metadata and run ONNX inference (embedding + reconstruction error), following the existing enrichment model patterns.

Changes:

  • Introduce RtfModel with input construction (build_input) and two inference paths (predict_embed, predict_recon) plus a combined predict_alert.
  • Add RtfOutput (embedding vector + anomaly score) for downstream consumption.
  • Export the new model module via src/enrichment/models/mod.rs.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| src/enrichment/models/rtf.rs | Implements RTF preprocessing, padding mask creation, and ONNX inference for embedding + reconstruction error. |
| src/enrichment/models/mod.rs | Registers and re-exports the new RTF model types. |

Comment on lines +145 to +147
// logflux = -0.4 * magpsf (log10 flux in ZP=23.9 system)
x[[0, i, 2]] = -0.4 * mag;
// logflux_err = 0.4 * sigmapsf

Copilot AI Apr 23, 2026

The log-flux calculation/documentation is internally inconsistent: the comment says “log10 flux in ZP=23.9 system”, but the code uses -0.4 * mag (missing the + 0.4 * ZTF_ZP offset implied by that statement). Please either (a) update the computation to match the ZP=23.9 convention used elsewhere in the codebase (see mag2flux in src/utils/lightcurves.rs), or (b) adjust the comment/docs to reflect the actual convention the ONNX model was trained on so inference isn’t silently mismatched.

Suggested change
- // logflux = -0.4 * magpsf (log10 flux in ZP=23.9 system)
- x[[0, i, 2]] = -0.4 * mag;
- // logflux_err = 0.4 * sigmapsf
+ // Model input uses the training-time convention directly:
+ // logflux = -0.4 * magpsf and logflux_err = 0.4 * sigmapsf.
+ // This is a magnitude-derived transform, not a ZP=23.9 flux conversion.
+ x[[0, i, 2]] = -0.4 * mag;
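
To make the mismatch concrete, the two conventions can be written side by side. This is a small sketch rather than code from the PR: `logflux_model`, `logflux_zp`, and `logflux_err` are hypothetical helper names, and `ZTF_ZP` is assumed to be an `f32` constant equal to the 23.9 quoted in the comment. The two log-flux values differ by the constant 0.4 × 23.9 = 9.56, while the error term is the same in both because a constant offset does not change it.

```rust
/// Zero point quoted in the code comment under review (assumed to be f32 here).
const ZTF_ZP: f32 = 23.9;

/// Convention the PR code actually applies: a pure rescaling of magnitude.
fn logflux_model(mag: f32) -> f32 {
    -0.4 * mag
}

/// log10 of flux for a given zero point: log10(f) = 0.4 * (zp - mag).
fn logflux_zp(mag: f32, zp: f32) -> f32 {
    0.4 * (zp - mag)
}

/// Error term from the PR comment: 0.4 * sigmapsf, identical in both conventions.
fn logflux_err(sigmapsf: f32) -> f32 {
    0.4 * sigmapsf
}
```

For `mag = 18.5`, `logflux_model` gives -7.4 and `logflux_zp(18.5, ZTF_ZP)` gives 2.16. Which one is right depends only on what the ONNX model saw at training time, so the offset matters if the model is not invariant to a constant shift in that input column.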

Comment on lines +83 to +95
// Current candidate
let band_idx = band_to_idx(&alert.candidate.band);
detections.push((candidate.jd, candidate.magpsf, candidate.sigmapsf, band_idx));

// Previous candidates
for phot in &alert.prv_candidates {
    if let (Some(mag), Some(sig)) = (phot.magpsf, phot.sigmapsf) {
        let mag = mag as f32;
        let sig = sig as f32;
        let idx = band_to_idx(&phot.band);
        detections.push((phot.jd, mag, sig, idx));
    }
}

Copilot AI Apr 23, 2026

build_input unconditionally pushes the current candidate into detections, but ZTF_alerts_aux.prv_candidates is already populated with the current candidate during ingestion (see src/alert/ztf.rs where the current candidate is appended if missing, then sanitized/deduped by jd). Since ZtfAlertForEnrichment.prv_candidates is loaded from aux, this likely duplicates the latest detection and can skew dt_prev/band one-hot/features. Consider either (a) only adding the current candidate if it’s not already present (by jd/candid), or (b) deduplicating/sanitizing the combined detections list after collection.
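
Option (b) could look roughly like the following. It is a sketch under stated assumptions, not the PR's code: `Detection` and `dedup_by_jd` are hypothetical names, `jd` is assumed to be an `f64`, and two detections with exactly the same `jd` are treated as the same measurement (matching on candid may be the better key in the real code).

```rust
/// (jd, magpsf, sigmapsf, band index), mirroring the tuple in the quoted diff.
type Detection = (f64, f32, f32, usize);

/// Sorts detections by time and drops entries whose jd repeats an earlier one.
/// ASSUMPTION: equal jd means the same measurement.
fn dedup_by_jd(mut detections: Vec<Detection>) -> Vec<Detection> {
    // total_cmp gives a total order on f64, so a NaN jd cannot panic the sort.
    detections.sort_by(|a, b| a.0.total_cmp(&b.0));
    // dedup_by drops `later` whenever the closure returns true for (later, earlier).
    detections.dedup_by(|later, earlier| later.0 == earlier.0);
    detections
}
```

Because `sort_by` is stable, the copy that was pushed first survives among equal-`jd` entries. In the quoted code that is the current candidate, since it is pushed before the `prv_candidates` loop.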

Adds a Rust data loader for the TEMPO transient/variable classifier
(EvidentialClassifier with Dirichlet uncertainty). The model classifies
ZTF alerts into 5 classes (SNI, SNII, TDE, AGN, CV).

Key differences from RTF:
- Simpler 5-dim input tensor (no alert metadata, no images)
- Drops i-band observations per training config
- Normalizes continuous features with pre-computed mean/std
- Computes 24 physics summary features as a separate model input
- Outputs class probabilities via evidence -> alpha -> Dirichlet mean

ONNX model file (tempo_classifier.onnx) to be hosted on HuggingFace.
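
The evidence -> alpha -> Dirichlet mean step can be sketched as below. This follows the common evidential-classification convention (alpha = evidence + 1, uncertainty = K / sum of alpha) and is an assumption about TEMPO, not something the commit message confirms; `dirichlet_summary` is a hypothetical name.

```rust
/// Converts per-class evidence into class probabilities (the Dirichlet mean)
/// and a scalar uncertainty.
/// ASSUMPTIONS: alpha = evidence + 1 and uncertainty = K / sum(alpha);
/// TEMPO's own definitions may differ. Expects a non-empty slice.
fn dirichlet_summary(evidence: &[f32]) -> (Vec<f32>, f32) {
    // Clamp negatives to zero so every alpha is at least 1.
    let alpha: Vec<f32> = evidence.iter().map(|&e| e.max(0.0) + 1.0).collect();
    let strength: f32 = alpha.iter().sum();
    // Dirichlet mean: each alpha divided by the total strength.
    let probs: Vec<f32> = alpha.iter().map(|a| a / strength).collect();
    // With zero evidence everywhere this equals 1.0, the maximum.
    let uncertainty = evidence.len() as f32 / strength;
    (probs, uncertainty)
}
```

For the five classes listed above, evidence of `[9, 0, 0, 0, 0]` gives alpha `[10, 1, 1, 1, 1]`, a top-class probability of 10/14, and an uncertainty of 5/14. Zero evidence for every class gives uniform probabilities of 0.2 and an uncertainty of 1.
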
- Add DEFAULT_UNCERTAINTY_THRESHOLD (0.25) derived from validation
  (correct=0.192, incorrect=0.260 mean uncertainty)
- Add passes_threshold bool and uncertainty_threshold fields to TempoOutput
- Add predict_alert_with_threshold() for custom threshold overrides
- Rejecting the ~20% most uncertain predictions pushes accuracy to ~100%
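
The thresholding described in this commit can be sketched as a small gate on top of the class probabilities. `GatedPrediction` and `gate_prediction` are hypothetical stand-ins for the PR's `TempoOutput` and `predict_alert_with_threshold()`, whose real fields and signatures are not shown here; whether the comparison is `<=` or `<` is also an assumption.

```rust
/// Default cutoff quoted in the commit message.
const DEFAULT_UNCERTAINTY_THRESHOLD: f32 = 0.25;

/// Hypothetical, trimmed-down stand-in for the PR's TempoOutput.
#[derive(Debug, Clone, PartialEq)]
struct GatedPrediction {
    class_index: usize,
    uncertainty: f32,
    uncertainty_threshold: f32,
    passes_threshold: bool,
}

/// Picks the most probable class and flags whether the prediction is
/// confident enough. Passing `None` uses the default threshold.
/// ASSUMPTION: a prediction passes when uncertainty <= threshold.
fn gate_prediction(probs: &[f32], uncertainty: f32, threshold: Option<f32>) -> GatedPrediction {
    let uncertainty_threshold = threshold.unwrap_or(DEFAULT_UNCERTAINTY_THRESHOLD);
    // Argmax over class probabilities; falls back to 0 for an empty slice.
    let class_index = probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .unwrap_or(0);
    GatedPrediction {
        class_index,
        uncertainty,
        uncertainty_threshold,
        passes_threshold: uncertainty <= uncertainty_threshold,
    }
}
```

The default of 0.25 sits between the validation means quoted above (0.192 for correct and 0.260 for incorrect predictions), so a prediction at the correct-class mean passes and one at the incorrect-class mean is rejected.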